{
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# Licensed to the Apache Software Foundation (ASF) under one\n",
        "# or more contributor license agreements. See the NOTICE file\n",
        "# distributed with this work for additional information\n",
        "# regarding copyright ownership. The ASF licenses this file\n",
        "# to you under the Apache License, Version 2.0 (the\n",
        "# \"License\"); you may not use this file except in compliance\n",
        "# with the License. You may obtain a copy of the License at\n",
        "#\n",
        "#   http://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing,\n",
        "# software distributed under the License is distributed on an\n",
        "# \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY\n",
        "# KIND, either express or implied. See the License for the\n",
        "# specific language governing permissions and limitations\n",
        "# under the License"
      ],
      "metadata": {
        "id": "H-YbtpqChYYo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "a5343d14"
      },
      "source": [
        "# Running Dataflow on TPUs: Quickstart examples\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/apache/beam/blob/master/examples/notebooks/beam-ml/dataflow_tpu_examples.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/colab_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://github.com/apache/beam/blob/master/examples/notebooks/beam-ml/dataflow_tpu_examples.ipynb\"><img src=\"https://raw.githubusercontent.com/google/or-tools/main/tools/github_32px.png\" />View source on GitHub</a>\n",
        "  </td>\n",
        "</table>\n",
        "<br/>\n",
        "<br/>"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "This Colab notebook shows you how to set up two pipelines:\n",
        "1. A pipeline that runs a trivial computation on a TPU.\n",
        "2. A pipeline that runs inference using the [Gemma-3-27b-it model](https://huggingface.co/google/gemma-3-27b-it) on TPUs .\n",
        "\n",
        "Both pipelines use a custom Docker image. The Dataflow jobs will launch using a [Flex Template](https://cloud.google.com/dataflow/docs/guides/templates/using-flex-templates) to allow the same job to be reproduced in different Colab environments."
      ],
      "metadata": {
        "id": "hAm4UpVHimSr"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8L0c_bikJt4d"
      },
      "source": [
        "## Prerequisites"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "i5IAopB4ewpu"
      },
      "source": [
        "First, you need to authenticate to your Google Cloud Project. After running the cell below, you might need to **click on the text prompts in the cell** and enter inputs as prompted.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OdZ5bkvwesGg"
      },
      "outputs": [],
      "source": [
        "import sys\n",
        "if 'google.colab' in sys.modules:\n",
        "    from google.colab import auth\n",
        "    auth.authenticate_user()\n",
        "!gcloud auth login"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "dEFtSATYJp6p"
      },
      "source": [
        "Now, set environment variables to access pipeline resources, such as a\n",
        "Cloud Storage bucket or a repository to host container images in Artifact Registry."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": true,
        "id": "cMJS0sYBfNkI"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import datetime\n",
        "\n",
        "project_id = \"some-project\" # @param {type:\"string\"}\n",
        "gcs_bucket = \"some-bucket\" # @param {type:\"string\"}\n",
        "ar_repository = \"some-ar-repo\" # @param {type:\"string\"}\n",
        "\n",
        "# Use a region where you have TPU accelerator quota.\n",
        "region = \"some-region1\" # @param {type:\"string\"}\n",
        "!gcloud config set project {project_id}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vIrXayHQL-d6"
      },
      "source": [
        "Enable the necessary APIs if your project hasn't enabled them yet. If you have the appropriate permissions, you can enable the APIs by running the following cell."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_jKxVSK_MBFr"
      },
      "outputs": [],
      "source": [
        "!gcloud services enable \\\n",
        "    dataflow.googleapis.com \\\n",
        "    compute.googleapis.com \\\n",
        "    logging.googleapis.com \\\n",
        "    storage.googleapis.com \\\n",
        "    cloudresourcemanager.googleapis.com \\\n",
        "    artifactregistry.googleapis.com \\\n",
        "    cloudbuild.googleapis.com"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lS3V0Sh5MbtT"
      },
      "source": [
        "Now, you'll create a Cloud Storage bucket and Artifact Registry repository if you don't already have these resources."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "8Wrs8yUhMas7"
      },
      "outputs": [],
      "source": [
        "!gcloud storage buckets describe gs://{gcs_bucket} >/dev/null 2>&1 || gcloud storage buckets create gs://{gcs_bucket} --location={region}\n",
        "!gcloud artifacts repositories describe {ar_repository} --location={region} >/dev/null 2>&1 || gcloud artifacts repositories create {ar_repository} --repository-format=docker --location={region}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Uv12ZxPVcTEc"
      },
      "source": [
        "# Example 1: Minimal computation pipeline using TPU V5E\n",
        "\n",
        "First, create a simple pipeline you can run to verify that TPUs are accessible, your custom Docker image has the necessary dependencies to interact with the TPUs and your Dataflow pipeline launch configuration is valid.\n",
        "\n",
        "With this sample you use the PyTorch library to interact with a TPU device."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "31f4cabb"
      },
      "outputs": [],
      "source": [
        "%%writefile minimal_tpu_pipeline.py\n",
        "from __future__ import annotations\n",
        "import torch\n",
        "import torch_xla\n",
        "import argparse\n",
        "import logging\n",
        "import apache_beam as beam\n",
        "from apache_beam.options.pipeline_options import PipelineOptions\n",
        "\n",
        "\n",
        "class check_tpus(beam.DoFn):\n",
        "    \"\"\"Validates that a TPU is accessible.\"\"\"\n",
        "    def setup(self):\n",
        "        tpu_devices = torch_xla.xm.get_xla_supported_devices()\n",
        "        if not tpu_devices:\n",
        "            raise RuntimeError(\"No TPUs found on the worker.\")\n",
        "        logging.info(f\"Found TPU devices: {tpu_devices}\")\n",
        "        tpu = torch_xla.device()\n",
        "        t1 = torch.randn(3, 3, device=tpu)\n",
        "        t2 = torch.randn(3, 3, device=tpu)\n",
        "        result = t1 + t2\n",
        "        logging.info(f\"Result of a sample TPU computation: {result}\")\n",
        "\n",
        "    def process(self, element):\n",
        "        yield element\n",
        "\n",
        "\n",
        "def run(input_text: str, beam_args: list[str] | None = None) -> None:\n",
        "    beam_options = PipelineOptions(beam_args, save_main_session=True)\n",
        "    pipeline = beam.Pipeline(options=beam_options)\n",
        "    (\n",
        "        pipeline\n",
        "        | \"Create data\" >> beam.Create([input_text])\n",
        "        | \"Check TPU availability\" >> beam.ParDo(check_tpus())\n",
        "        | \"My transform\" >> beam.LogElements(level=logging.INFO)\n",
        "    )\n",
        "    pipeline.run()\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    logging.getLogger().setLevel(logging.INFO)\n",
        "\n",
        "    parser = argparse.ArgumentParser()\n",
        "    parser.add_argument(\n",
        "        \"--input-text\",\n",
        "        default=\"Hello! This pipeline verified that TPUs are accessible.\",\n",
        "        help=\"Input text to display.\",\n",
        "    )\n",
        "    args, beam_args = parser.parse_known_args()\n",
        "\n",
        "    run(args.input_text, beam_args)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4516f3e0"
      },
      "source": [
        "## Create a Dockerfile for your TPU-compatible container image.\n",
        "\n",
        "In your Dockerfile you configure the environment variables to use with a `V5E` `1x1` TPU device.\n",
        "\n",
        "**You must use the region where you have V5E TPU quota to run this example.**\n",
        "\n",
        "To use a different TPU, adjust the configuration according to the [Dataflow documentation](https://cloud.google.com/dataflow/docs/tpu/use-tpus).\n",
        "\n",
        "This Dockerfile creates an image that serves both as a custom worker image for your Beam pipeline and also as a launcher image for your Flex template."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_EY1_rmXdAM5"
      },
      "outputs": [],
      "source": [
        "%%writefile Dockerfile\n",
        "\n",
        "FROM python:3.11-slim\n",
        "\n",
        "COPY minimal_tpu_pipeline.py minimal_tpu_pipeline.py\n",
        "\n",
        "# Copy the Apache Beam worker dependencies from the Beam Python 3.10 SDK image.\n",
        "COPY --from=apache/beam_python3.10_sdk:2.67.0 /opt/apache/beam /opt/apache/beam\n",
        "\n",
        "# Copy Template Launcher dependencies\n",
        "COPY --from=gcr.io/dataflow-templates-base/python310-template-launcher-base /opt/google/dataflow/python_template_launcher /opt/google/dataflow/python_template_launcher\n",
        "\n",
        "# Install TPU software and Apache Beam SDK\n",
        "RUN pip install --no-cache-dir torch~=2.8.0 torch_xla[tpu]~=2.8.0 apache-beam[gcp]==2.67.0 -f https://storage.googleapis.com/libtpu-releases/index.html\n",
        "\n",
        "# Configuration for v5e 1x1 accelerator type.\n",
        "ENV TPU_CHIPS_PER_HOST_BOUNDS=1,1,1\n",
        "ENV TPU_ACCELERATOR_TYPE=v5litepod-1\n",
        "ENV TPU_SKIP_MDS_QUERY=1\n",
        "ENV TPU_HOST_BOUNDS=1,1,1\n",
        "ENV TPU_WORKER_HOSTNAMES=localhost\n",
        "ENV TPU_WORKER_ID=0\n",
        "\n",
        "ENV FLEX_TEMPLATE_PYTHON_PY_FILE=minimal_tpu_pipeline.py\n",
        "\n",
        "# Set the entrypoint to Apache Beam SDK worker launcher.\n",
        "ENTRYPOINT [ \"/opt/apache/beam/boot\"]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XBFIEqNmenRj"
      },
      "source": [
        "## Push your Docker image to Artifact Registry."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "F9XQBZfrfbM2"
      },
      "source": [
        "Finally, build your Docker image, and push it in Artifact Registry. This process should take about 15 minutes or so."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UaA-sBC1fabY"
      },
      "outputs": [],
      "source": [
        "container_tag = \"20250801\"\n",
        "container_image = ''.join([\n",
        "    region, \"-docker.pkg.dev/\",\n",
        "    project_id, \"/\",\n",
        "    ar_repository, \"/\",\n",
        "    \"tpu-minimal-example\", \":\", container_tag\n",
        "])\n",
        "\n",
        "!gcloud builds submit --tag {container_image}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1chqESwuSerP"
      },
      "source": [
        "## Build the Dataflow Flex Template."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "I3ukh2lwlmm3"
      },
      "source": [
        "To create a reproducible environment for launching the pipeline, build a Flex Template.\n",
        "\n",
        "First, create a `metadata.json` file to change the default Dataflow worker disk size when launching the template.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GhlCMBnDl8-t"
      },
      "outputs": [],
      "source": [
        "%%writefile metadata.json\n",
        "{\n",
        "    \"name\": \"Minimal TPU Example on Dataflow\",\n",
        "    \"description\": \"A Flex template launching a Dataflow Job doing a TPU computation \",\n",
        "    \"parameters\": [\n",
        "      {\n",
        "        \"name\": \"disk_size_gb\",\n",
        "        \"label\": \"disk_size_gb\",\n",
        "        \"helpText\": \"disk_size_gb for worker\",\n",
        "        \"isOptional\": true\n",
        "      }\n",
        "    ]\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eQAX8rzJVtDS"
      },
      "source": [
        "Run the following cell to build the Flex Template and save it Cloud Storage."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CYLTC-jpSh6j"
      },
      "outputs": [],
      "source": [
        "!gcloud dataflow flex-template build gs://{gcs_bucket}/minimal_tpu_pipeline.json \\\n",
        "  --image {container_image} \\\n",
        "  --sdk-language \"PYTHON\" \\\n",
        "  --metadata-file metadata.json \\\n",
        "  --project {project_id}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cAhW0FdW5W7t"
      },
      "source": [
        "## Submit your pipeline to Dataflow.\n",
        "\n",
        "Since you launch the pipeline as a Flex Template, make the following adjustments to the command line:\n",
        "\n",
        "* Use `--parameters` option to specify the container image and disk size.\n",
        "* Use `--additional-experiments` option to specify the necessary Dataflow service options.\n",
        "* To avoid using more than one process on a TPU simultaneously, limit process-level parallelism with the `no_use_multiple_sdk_containers` experiment."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UVtBBPcWCzFu"
      },
      "outputs": [],
      "source": [
        "!gcloud dataflow flex-template run \"minimal-tpu-example-`date +%Y%m%d-%H%M%S`\" \\\n",
        "  --template-file-gcs-location gs://{gcs_bucket}/minimal_tpu_pipeline.json \\\n",
        "  --region {region} \\\n",
        "  --project {project_id} \\\n",
        "  --temp-location gs://{gcs_bucket}/tmp \\\n",
        "  --parameters sdk_container_image={container_image} \\\n",
        "  --worker-machine-type \"ct5lp-hightpu-1t\" \\\n",
        "  --parameters disk_size_gb=50 \\\n",
        "  --additional-experiments \"worker_accelerator=type:tpu-v5-lite-podslice;topology:1x1\" \\\n",
        "  --additional-experiments \"no_use_multiple_sdk_containers\"\n"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Once the job is launched, use the following link to monitor its status: https://console.cloud.google.com/dataflow/jobs/\n",
        "\n",
        "Sample worker logs for the `Check TPU availability` step look like the following:\n",
        "\n",
        "```\n",
        "Found TPU devices: ['xla:0']\n",
        "Result of a sample TPU computation: tensor([[ 0.3355, -1.4628, -3.2610], [-1.4656, 0.3196, -2.8766], [ 0.8667, -1.5060, 0.7125]], device='xla:0')\n",
        "```"
      ],
      "metadata": {
        "id": "xRW_d_i_tVel"
      }
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DpUAUjDlcMOR"
      },
      "source": [
        "# Example 2: Inference Pipeline with Gemma 3 27B using TPU V6E\n",
        "\n",
        "This example shows you how to perform inference on a TPU using Gemma 3 27b model.\n",
        "\n",
        "To fit this model in TPU memory, you need four V6E TPU chips connected in 2x2 topology.\n",
        "\n",
        "**You must use the region where you have V6E TPU quota to run this example.**\n",
        "\n",
        "The example uses [Apache Beam RunInference APIs](https://beam.apache.org/documentation/transforms/python/elementwise/runinference/) with the [VLLM Completions model handler](https://beam.apache.org/releases/pydoc/current/apache_beam.ml.inference.vllm_inference.html).\n",
        "\n",
        "The model is downloaded from HuggingFace at runtime, and running the example requires a [HuggingFace access token](https://huggingface.co/docs/hub/en/security-tokens).\n",
        "\n",
        "First, create a pipeline file."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "GGCqkzgXda97"
      },
      "outputs": [],
      "source": [
        "%%writefile gemma_tpu_pipeline.py\n",
        "from __future__ import annotations\n",
        "import argparse\n",
        "import logging\n",
        "import apache_beam as beam\n",
        "from apache_beam.ml.inference.base import RunInference\n",
        "from apache_beam.options.pipeline_options import PipelineOptions\n",
        "from apache_beam.ml.inference.vllm_inference import VLLMCompletionsModelHandler\n",
        "\n",
        "\n",
        "def run(input_text: str, beam_args: list[str] | None = None) -> None:\n",
        "    beam_options = PipelineOptions(beam_args, save_main_session=True)\n",
        "    pipeline = beam.Pipeline(options=beam_options)\n",
        "    (\n",
        "        pipeline\n",
        "        | \"Create data\" >> beam.Create([input_text])\n",
        "        | \"Run Inference\" >> RunInference(\n",
        "            model_handler=VLLMCompletionsModelHandler(\n",
        "                'google/gemma-3-27b-it',\n",
        "                {\n",
        "                    'max-model-len': '4096',\n",
        "                    'no-enable-prefix-caching': None,\n",
        "                    'disable-log-requests': None,\n",
        "                    'tensor-parallel-size': '4',\n",
        "                    'limit-mm-per-prompt': '{\"image\": 0}'\n",
        "                })\n",
        "            )\n",
        "        | \"Log Output\" >> beam.LogElements(level=logging.INFO)\n",
        "    )\n",
        "    pipeline.run()\n",
        "\n",
        "\n",
        "if __name__ == \"__main__\":\n",
        "    logging.getLogger().setLevel(logging.INFO)\n",
        "    parser = argparse.ArgumentParser()\n",
        "    parser.add_argument(\n",
        "        \"--input-text\",\n",
        "        default=\"What are TPUs?\",\n",
        "        help=\"Input text query.\",\n",
        "    )\n",
        "    args, beam_args = parser.parse_known_args()\n",
        "    run(args.input_text, beam_args)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "emTd69gUonq-"
      },
      "source": [
        "## Create a new Dockerfile for this pipeline with additional dependencies.\n",
        "Note that this sample uses a different TPU device than the example 1, so the environment variables are different.\n",
        "\n",
        "**You must use your own HuggingFace Token in the Dockerfile.** For instructions on creating a token, see [User access tokens](https://huggingface.co/docs/hub/en/security-tokens)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6elKvBZ0_dc4"
      },
      "outputs": [],
      "source": [
        "%%writefile Dockerfile\n",
        "# Use the official vLLM TPU base image, which has TPU dependencies.\n",
        "# To use the latest version, use: vllm/vllm-tpu:nightly\n",
        "FROM vllm/vllm-tpu:5964069367a7d54c3816ce3faba79e02110cde17\n",
        "\n",
        "# Copy your pipeline file.\n",
        "COPY gemma_tpu_pipeline.py gemma_tpu_pipeline.py\n",
        "\n",
        "# You can use a more recent version of Apache Beam\n",
        "COPY --from=apache/beam_python3.12_sdk:2.67.0 /opt/apache/beam /opt/apache/beam\n",
        "RUN pip install --no-cache-dir apache-beam[gcp]==2.67.0\n",
        "\n",
        "# Copy Template Launcher dependencies\n",
        "COPY --from=gcr.io/dataflow-templates-base/python310-template-launcher-base /opt/google/dataflow/python_template_launcher /opt/google/dataflow/python_template_launcher\n",
        "\n",
        "# Replace the Hugginface token here.\n",
        "RUN python -c 'from huggingface_hub import HfFolder; HfFolder.save_token(\"YOUR HUGGINGFACE TOKEN\")'\n",
        "\n",
        "# TPU environment variables.\n",
        "ENV TPU_SKIP_MDS_QUERY=1\n",
        "\n",
        "# Configuration for v6e 2x2 accelerator type.\n",
        "ENV TPU_HOST_BOUNDS=1,1,1\n",
        "ENV TPU_CHIPS_PER_HOST_BOUNDS=2,2,1\n",
        "ENV TPU_ACCELERATOR_TYPE=v6e-4\n",
        "ENV VLLM_USE_V1=1\n",
        "\n",
        "ENV FLEX_TEMPLATE_PYTHON_PY_FILE=gemma_tpu_pipeline.py\n",
        "\n",
        "# Set the entrypoint to Apache Beam SDK worker launcher.\n",
        "ENTRYPOINT [ \"/opt/apache/beam/boot\"]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "2V1PmAf1otG4"
      },
      "source": [
        "Run the following cell to build the Docker image and push it to Artifact Registry. This process should take 15 min or so."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": true,
        "id": "He5WkUAE_pYp"
      },
      "outputs": [],
      "source": [
        "container_tag = \"20250801\"\n",
        "container_image = ''.join([\n",
        "    region, \"-docker.pkg.dev/\",\n",
        "    project_id, \"/\",\n",
        "    ar_repository, \"/\",\n",
        "    \"tpu-run-inference-example\", \":\", container_tag\n",
        "])\n",
        "!gcloud builds submit --tag {container_image}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EyYuSgDudcVK"
      },
      "source": [
        "## Build the Flex Template for this pipeline."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "33V96JFAL_jk"
      },
      "source": [
        "To create a reproducible environment for launching the pipeline, build a Flex Template.\n",
        "\n",
        "First, create a `metadata.json` file to change the default Dataflow worker disk size when launching the template."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "L8hylI64L_jl"
      },
      "outputs": [],
      "source": [
        "%%writefile metadata.json\n",
        "{\n",
        "    \"name\": \"Gemma 3 27b Run Inference pipeline with VLLM\",\n",
        "    \"description\": \"A template for Dataflow RunInference pipeline with VLLM in a TPU-enabled environment with VLLM\",\n",
        "    \"parameters\": [\n",
        "      {\n",
        "        \"name\": \"disk_size_gb\",\n",
        "        \"label\": \"disk_size_gb\",\n",
        "        \"helpText\": \"disk_size_gb for worker\",\n",
        "        \"isOptional\": true\n",
        "      }\n",
        "    ]\n",
        "}"
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Run the following cell to build the Flex Template and save it in Cloud Storage."
      ],
      "metadata": {
        "id": "yGRhrD1J2IIW"
      }
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Hvs2JWNydiBl"
      },
      "outputs": [],
      "source": [
        "!gcloud dataflow flex-template build gs://{gcs_bucket}/gemma_tpu_pipeline.json \\\n",
        "  --image {container_image} \\\n",
        "  --sdk-language \"PYTHON\" \\\n",
        "  --metadata-file metadata.json \\\n",
        "  --project {project_id}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VWWaf4cmdi7z"
      },
      "source": [
        "## Finally, submit the job to Dataflow.\n",
        "\n",
        "Since you launch the pipeline as a Flex Template, you are making the following adjustments to the command line:\n",
        "\n",
        "* Use the `--parameters` option to specify the container image and disk size\n",
        "* Use the `--additional-experiments` option to specify the necessary Dataflow service options.\n",
        "* The VLLMCompletionsModelHandler from Beam RunInference APIs only loads the model onto TPUs from a single process. Still, limit the intra-worker parallelism by reducing the value of\n",
        "`--number_of_worker_harness_threads`, which achieves better performance.\n",
        "\n",
        "Once the job is launched, use the following link to monitor its status: https://console.cloud.google.com/dataflow/jobs/"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "OUX0E0XzdlLW"
      },
      "outputs": [],
      "source": [
        "!gcloud dataflow flex-template run \"gemma-tpu-example-`date +%Y%m%d-%H%M%S`\" \\\n",
        "  --template-file-gcs-location gs://{gcs_bucket}/gemma_tpu_pipeline.json \\\n",
        "  --region {region} \\\n",
        "  --project {project_id} \\\n",
        "  --temp-location gs://{gcs_bucket}/tmp \\\n",
        "  --parameters number_of_worker_harness_threads=100 \\\n",
        "  --parameters sdk_container_image={container_image} \\\n",
        "  --parameters disk_size_gb=100 \\\n",
        "  --worker-machine-type \"ct6e-standard-4t\" \\\n",
        "  --additional-experiments \"worker_accelerator=type:tpu-v6e-slice;topology:2x2\""
      ]
    },
    {
      "cell_type": "markdown",
      "source": [
        "Due to model loading and initialization time, the pipeline takes 25 min or so to complete.\n",
        "\n",
        "Sample worker logs for the `Run Inference` step look like the following:\n",
        "\n",
        "```\n",
        "PredictionResult(example='What are TPUs?', inference=Completion(id='cmpl-57ebbddeb1c04dc0a8a74f2b60d10f67', choices=[CompletionChoice(finish_reason='length', index=0, logprobs=None, text='\\n\\nTensor Processing Units (TPUs) are custom-developed AI accelerator ASICs', stop_reason=None, prompt_logprobs=None)], created=1755614936, model='google/gemma-3-27b-it', object='text_completion', system_fingerprint=None, usage=CompletionUsage(completion_tokens=16, prompt_tokens=6, total_tokens=22, completion_tokens_details=None, prompt_tokens_details=None), service_tier=None, kv_transfer_params=None), model_id=None)\n",
        "```"
      ],
      "metadata": {
        "id": "1kpeVbdczt8u"
      }
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}